{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using the Apache MXNet Module API with SageMaker Training and Batch Transformation\n",
    "\n",
    "The SageMaker Python SDK makes it easy to train MXNet models and use them for batch transformation. In this example, we train a simple neural network using the Apache MXNet [Module API](https://mxnet.incubator.apache.org/api/python/module.html) and the MNIST dataset. The MNIST dataset is widely used for handwritten digit classification and consists of 70,000 labeled 28x28 pixel grayscale images of handwritten digits. The dataset is split into 60,000 training images and 10,000 test images. There are 10 classes (one for each of the 10 digits). The task is to train a model on the 60,000 training images and then test its classification accuracy on the 10,000 test images."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup\n",
    "\n",
    "First, we define a few variables that will be needed later in the example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "isConfigCell": true
   },
   "outputs": [],
   "source": [
    "from sagemaker import get_execution_role\n",
    "from sagemaker.session import Session\n",
    "\n",
    "sagemaker_session = Session()\n",
    "region = sagemaker_session.boto_session.region_name\n",
    "sample_data_bucket = 'sagemaker-sample-data-{}'.format(region)\n",
    "\n",
    "# S3 bucket for saving files. Feel free to redefine this variable to the bucket of your choice.\n",
    "bucket = sagemaker_session.default_bucket()\n",
    "\n",
    "# Bucket location where your custom code will be saved in the tar.gz format.\n",
    "custom_code_upload_location = 's3://{}/mxnet-mnist-example/code'.format(bucket)\n",
    "\n",
    "# Bucket location where results of model training are saved.\n",
    "model_artifacts_location = 's3://{}/mxnet-mnist-example/artifacts'.format(bucket)\n",
    "\n",
    "# IAM execution role that gives SageMaker access to resources in your AWS account.\n",
    "# We can use the SageMaker Python SDK to get the role from our notebook environment. \n",
    "role = get_execution_role()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training and inference script\n",
    "\n",
    "The `mnist.py` script provides all the code we need for training and inference. The script we will use is adapted from the Apache MXNet [MNIST tutorial](https://mxnet.incubator.apache.org/tutorials/python/mnist.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat mnist.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SageMaker's MXNet estimator class\n",
    "\n",
    "The SageMaker ``MXNet`` estimator allows us to run single-machine or distributed training in SageMaker, using CPU or GPU-based instances.\n",
    "\n",
    "When we create the estimator, we pass in the filename of our training script, our IAM execution role, and the S3 locations we defined in the setup section. We also provide a few other parameters. ``train_instance_count`` and ``train_instance_type`` determine the number and type of SageMaker instances that will be used for the training job. The ``hyperparameters`` parameter is a ``dict`` of values that will be passed to your training script -- you can see how to access these values in the ``mnist.py`` script above.\n",
    "\n",
    "For this example, we will choose one ``ml.m4.xlarge`` instance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.mxnet import MXNet\n",
    "\n",
    "mnist_estimator = MXNet(entry_point='mnist.py',\n",
    "                        role=role,\n",
    "                        output_path=model_artifacts_location,\n",
    "                        code_location=custom_code_upload_location,\n",
    "                        train_instance_count=1,\n",
    "                        train_instance_type='ml.m4.xlarge',\n",
    "                        framework_version='1.4.0',\n",
    "                        hyperparameters={'learning-rate': 0.1})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Running a training job\n",
    "\n",
    "After we've constructed our MXNet object, we can fit it using data stored in S3. Below we run SageMaker training on two input channels: train and test.\n",
    "\n",
    "During training, SageMaker makes this data stored in S3 available in the local filesystem where the `mnist.py` script is running. The script then simply loads the train and test data from disk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "train_data_location = 's3://{}/mxnet/mnist/train'.format(sample_data_bucket)\n",
    "test_data_location = 's3://{}/mxnet/mnist/test'.format(sample_data_bucket)\n",
    "\n",
    "mnist_estimator.fit({'train': train_data_location, 'test': test_data_location})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SageMaker's transformer class\n",
    "\n",
    "After training, we use our MXNet estimator object to create a `Transformer` by invoking the `transformer()` method. This method takes arguments for configuring the batch transform job; these do not need to be the same values as the ones we used for the training job. The method also creates a SageMaker Model to be used for the batch transform jobs.\n",
    "\n",
    "The `Transformer` class is responsible for running batch transform jobs, which host the trained model on the requested instances for the duration of the job and send it requests to perform inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer = mnist_estimator.transformer(instance_count=1, instance_type='ml.m4.xlarge')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Running a batch transform job\n",
    "\n",
    "Now we can perform some inference with the model we've trained by running a batch transform job. The request-handling behavior of the model served during the transform job is determined by the `mnist.py` script.\n",
    "\n",
    "For demonstration purposes, we're going to use input data that contains 1,000 MNIST images, located in the public SageMaker sample data S3 bucket. To create the batch transform job, we simply call `transform()` on our transformer with information about the input data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_file_path = 'batch-transform/mnist-1000-samples'\n",
    "\n",
    "transformer.transform('s3://{}/{}'.format(sample_data_bucket, input_file_path), content_type='text/csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we wait for the batch transform job to complete. The transformer has a convenience method, `wait()`, that blocks until the batch transform job has completed. If the job is still running when we call it, the cell keeps running and finishes only once the job is done."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer.wait()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Downloading the results\n",
    "\n",
    "The batch transform job uploads its predictions to S3. Since we did not specify `output_path` when creating the Transformer, one was generated based on the batch transform job name:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(transformer.output_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output here will be a list of predictions, where each prediction is a list of probabilities, one for each possible label. Since we read the output as a string, we use `ast.literal_eval()` to turn it into a list; the index of the maximum element in that list gives us the predicted label. Here we define a convenience function that takes the output and produces the predicted label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "\n",
    "def predicted_label(transform_output):\n",
    "    output = ast.literal_eval(transform_output)\n",
    "    probabilities = output[0]\n",
    "    return probabilities.index(max(probabilities))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's download the first ten results from S3:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from urllib.parse import urlparse\n",
    "\n",
    "import boto3\n",
    "\n",
    "parsed_url = urlparse(transformer.output_path)\n",
    "bucket_name = parsed_url.netloc\n",
    "prefix = parsed_url.path[1:]\n",
    "\n",
    "s3 = boto3.resource('s3')\n",
    "\n",
    "predictions = []\n",
    "for i in range(10):\n",
    "    file_key = '{}/data-{}.csv.out'.format(prefix, i)\n",
    "\n",
    "    output_obj = s3.Object(bucket_name, file_key)\n",
    "    output = output_obj.get()[\"Body\"].read().decode('utf-8')\n",
    "    \n",
    "    predictions.append(predicted_label(output))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For demonstration purposes, we're also going to download the corresponding original input data so that we can see how the model did with its predictions. First, we create a local directory to hold the downloaded files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "tmp_dir = '/tmp/data'\n",
    "\n",
    "# Create the directory if it does not already exist.\n",
    "os.makedirs(tmp_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now we'll download the first ten input files and display the images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpy import genfromtxt\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.rcParams['figure.figsize'] = (2,10)\n",
    "\n",
    "def show_digit(img, caption='', subplot=None):\n",
    "    # Create a new set of axes unless the caller supplied one.\n",
    "    if subplot is None:\n",
    "        _, subplot = plt.subplots(1, 1)\n",
    "    imgr = img.reshape((28,28))\n",
    "    subplot.axis('off')\n",
    "    subplot.imshow(imgr, cmap='gray')\n",
    "    # Set the title on the axes we drew on, not on whatever axes is current.\n",
    "    subplot.set_title(caption)\n",
    "\n",
    "for i in range(10):\n",
    "    input_file_name = 'data-{}.csv'.format(i)\n",
    "    input_file_key = '{}/{}'.format(input_file_path, input_file_name)\n",
    "    \n",
    "    s3.Bucket(sample_data_bucket).download_file(input_file_key, os.path.join(tmp_dir, input_file_name))\n",
    "    input_data = genfromtxt(os.path.join(tmp_dir, input_file_name), delimiter=',')\n",
    "\n",
    "    show_digit(input_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "scrolled": true
   },
   "source": [
    "Here, we can see the original labels are:\n",
    "\n",
    "```\n",
    "7, 2, 1, 0, 4, 1, 4, 9, 5, 9\n",
    "```\n",
    "\n",
    "Now let's print out the predictions to compare:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(predictions)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_mxnet_p36",
   "language": "python",
   "name": "conda_mxnet_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.  Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
